Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713
Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713cspades wants to merge 13 commits intoNVIDIA:mainfrom
Conversation
50da1dc to
925d022
Compare
Greptile SummaryThis PR adds Key changes and callouts:
Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant TEModule as TE Module (__init__)
participant SDM as set_device_mesh()
participant FSDP2 as fully_shard()
participant RP as reset_parameters()
participant FWD as forward()
participant DCP as Torch DCP
User->>TEModule: __init__(tp_mesh, weight_mesh)
TEModule->>TEModule: init_fp8_metadata()
TEModule->>SDM: set_device_mesh(tp_mesh, weight_mesh)
SDM->>SDM: _convert_param_to_dtensor_param()<br/>plain param → DTensor(Shard/Replicate)
SDM->>SDM: set amax_reduction_group<br/>on Float8CurrentScalingQuantizer
SDM-->>TEModule: params are now DTensors
TEModule->>RP: reset_parameters(defer_init=device=="meta")
RP->>RP: _set_tensor_parallel_attributes()
User->>FSDP2: fully_shard(model, mesh[dp_dims])
FSDP2->>FSDP2: detects DTensor Shard(dim=0)<br/>→ uses _StridedShard for DP-TP overlap
FSDP2-->>User: model params are FSDP-sharded DTensors
Note over User,FSDP2: If meta device: call reset_parameters() now
loop Training Step
FSDP2->>FWD: all-gather DTensor shards → TP-sharded DTensor
FWD->>FWD: _extract_trainable_tensor_from_dtensor()<br/>_ToLocalIdentity preserves object identity
FWD->>FWD: TE C++ kernels on local Tensor
FWD-->>FSDP2: grad → DTensor.grad via _ToLocalIdentity.backward
end
User->>DCP: save({"app": AppState(model, optimizer)})
DCP->>DCP: AppState.state_dict()<br/>evict _extra_state, clear empty optim states
DCP-->>User: checkpoint written
User->>DCP: load(state_dict, checkpoint_id)
DCP->>DCP: AppState.load_state_dict()<br/>set_state_dict(strict=False)
DCP-->>User: model restored to pre-save state
|
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
Outdated
Show resolved
Hide resolved
4ec2947 to
dbb9d14
Compare
fcdd5bd to
c912f5b
Compare
bc82f02 to
267f1df
Compare
|
/te-ci L1 pytorch |
f0b3cae to
af7362a
Compare
9435382 to
15df86f
Compare
|
/te-ci L1 pytorch |
Signed-off-by: Cory Ye <cye@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
for more information, see https://pre-commit.ci
…ess. Signed-off-by: Cory Ye <cye@nvidia.com>
… are still model parity tested. Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Cory Ye <cye@nvidia.com>
7ea9ab6 to
7e0d3a9
Compare
Signed-off-by: Cory Ye <cye@nvidia.com>
f5579a2 to
82780a1
Compare
Summary
(H/F)SDP2 x TPstrided sharding, andDTensorFP8 parameters for Torch DCP checkpointing, across allTransformerEngineBaseModule(s).GroupedLinear, pending FSDP2 standalone pipe-cleaning. All other modules undertransformer_engine.pytorch.modulesare supported.FusibleOperationsupport is also a WIP, except forLayerNormorRMSNormwhich are TE modules.DTensor-based TP when unified by Torch DCP! In the Llama3 recipe, we useDTensor-based TP on thetorch.nn.Embedding, TransformerEngine-based TP on the LM head, and weight-tie the LM head to thetorch.nn.Embedding, which is why we do not need to callset_device_meshfor the LM head!Usage / Documentation
(
tp_meshandweight_meshcan also be passed inTEModule.__init__.)Details
DTensor Lifecycle in TransformerEngine
__init__metadevice with the appropriatetp_sizeand TP sharding strategy, e.g.parallel_modeandsequence_parallel.TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)DTensorwith appropriate TPplacement(s) based on the TP sharding strategy specified in__init__, usingtransformer_engine.pytorch.distributed._convert_param_to_dtensor_param.tp_meshis a 1-DDeviceMeshcontaining the TPProcessGroupthat will be registered with the TransformerEngine module.weight_meshis the 1-DDeviceMeshcontaining theProcessGroupthat shards TransformerEngine module weights, the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP and is mainly required for per-Tensor scaling recipes likeFloat8CurrentScaling.fully_shard(which responds to the TP placements) and prior toreset_parameters(defer_init=False), which quantizes parameters.__init__(tp_mesh, weight_mesh)for supported TransformerEngine modules.fully_shardshards the TransformerEngine model with FSDP2.fully_shardencounters TP sharding ondim=0, it will use a_StridedShardfor DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in theDeviceMeshandDTensor.placements. (SeeAppendixfor visualization of this sharding strategy.)reset_parametersis called if using meta device initialization.fully_shard. (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such asFusedAdammust be used to properly handle high-precision main weights.)Tensoris actually a TP-shardedDTensor, which deviates from the original FSDP2 paradigm where the all-gatheredTensoris fully-unsharded and theDTensorwrapping is discarded. To support theseDTensorcompute weights in TransformerEngine modules, we utilizetransformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensorto localize theDTensorand also inheritrequires_gradattribute from theDTensorparameter as the localTensorhas this un-set duringDTensor.from_local(Tensor)for FP8 parameters specifically!Tensorgradient is converted toDTensorand attached to theDTensor.gradattribute. Handled by DTensor <> Tensor Autograd conversion functions, and in the case ofFusibleOperation, casted during the backward implementation.QuantizedTensorStorageNone, we senduntyped_storage()to a default 1-byte storage that unblocks DCP checkpoint loading assertions using this as a definition for "emptiness". This is because a storage of 0 bytes is adata_ptr() = nullptrand breaks DCP.untyped_storageis not used anywhere in TransformerEngine, it may break code that usesStorageto figure out if a Tensor is empty or not, as nowQuantizedTensorstorage will always be a 1-byte storage even when both row and column data are not set. Those checks would instead need to compare the storage size against 1 byte instead of 0 bytes.Bugs
"shard"was the presumed weight sharding sub-mesh in theDTensor.device_mesh. Now, users can precisely specify their own custom weight-shardingDeviceMeshfor per-tensoramax_reduction_groupvia theset_device_mesh(weight_mesh)API.TransformerEngineBaseModule:self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}Testing
mainvs.cspades:cye/fsdp2-tp-dcpwith Megatron-LMmainon PyTorch25.11DelayedScalinghas DCP save/load disparity issues, i.e. on the scale of+/-1to theuint8parameter checkpoint!Appendix
_StridedShard- Using FSDP2 x TP Strided-ShardingWhen
redistribute'ing a global DTensor to(_StridedShard(dim=0, sf=2), Shard(dim=0)),DTensorwill perform the following steps:Shardplacements to the right of_StridedShard. (In the above example, since TP=2, the factor is 2.)[0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7].fully_shard, this has already been done via initializing the TransformerEngine module with TP and calling_convert_param_to_dtensor_param!_StridedShard.[0] [1] [2] [3]and[4] [5] [6] [7][0 4] [1 5] [2 6] [3 7], which are assigned to the_StridedShardranks.[0 1] [2 3] [4 5] [6 7]!Shardplacement.[0] [4]/[1] [5]/[2] [6]/[3] [7], which are assigned to theShardranks.[0] [1]/[2] [3]/[4] [5]/[6] [7]!PyTorch also supports the inverse / un-sharding of this
redistribute, which is literally the inverse of these simple operations! (Though things get a bit more complicated with un-even shards from odd-numbered dimension sizes.)Type of change
Changes
Please list the changes introduced in this PR:
Checklist: